import torch

from networks.skeleton_blocks import *

def init_weight(m):
    """Initialise a single module in place.

    ``nn.Conv1d``, ``nn.Linear`` and ``nn.ConvTranspose1d`` modules get a
    Xavier-normal weight and, when they have a bias, a bias of zeros.
    Any other module type is returned from immediately and left untouched,
    so this can be passed to ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        return
    nn.init.xavier_normal_(m.weight)
    has_bias = m.bias is not None
    if has_bias:
        nn.init.constant_(m.bias, 0)


def reparameterize(mu, logvar):
    """Sample ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, 1)`` elementwise.

    This is the VAE reparameterization trick: the randomness lives in
    ``eps``, so gradients flow to both ``mu`` and ``logvar``.

    Args:
        mu: mean tensor.
        logvar: log-variance tensor; the noise is drawn with its shape,
            dtype and device.

    Returns:
        A tensor with the broadcast shape of ``mu`` and ``logvar``.
    """
    std = torch.exp(0.5 * logvar)
    # randn_like replaces the legacy ``.data.new(size).normal_()`` idiom and
    # keeps dtype/device without touching ``.data``.
    eps = torch.randn_like(std)
    return eps * std + mu

class StyleContentEncoder(nn.Module):
    """Skeleton-aware encoder producing a spatial latent and a global latent.

    The input goes through ``n_down`` blocks. Each block is a stride-2
    ``SkeletonConv``, then a ``SkeletonPoolJoint``, then LeakyReLU(0.2).
    Every block shrinks the time axis, and pooling may merge joints: a new
    topology is recorded only when the pooled skeleton has fewer joints.
    Two heads then branch off the result:

    * spatial (``sp_*``): two stride-1 ``SkeletonConv`` layers, each followed
      by LeakyReLU and a non-affine ``InstanceNorm1d``, then 1x1 ``sp_mu`` /
      ``sp_logvar`` heads with ``sp_channel`` channels. These keep the
      time axis.
    * global (``gl_*``): a stride-2 ``SkeletonConv``, ``AdaptiveAvgPool1d(1)``,
      LeakyReLU, a 1x1 ``SkeletonConv`` and LeakyReLU, then 1x1 ``gl_mu`` /
      ``gl_logvar`` heads with ``gl_channel`` channels.

    Args:
        n_down: number of down-sampling blocks.
        topology: initial skeleton topology. ``len(topology)`` is the
            initial joint count.
        kernel_size: temporal kernel size. Padding is ``(kernel_size - 1) // 2``.
        action_dim: width of an optional per-sample action vector. When
            given, it is appended to every joint at every frame of the
            spatial branch input.
        style_dim: width of an optional per-sample style vector. When
            given, it is appended the same way to the global branch input.
        max_channel_per_joint: cap on the per-joint channel width, which
            doubles at every down block.
        in_channels_per_joint: per-joint width of the network input. The
            first conv has ``in_channels_per_joint * len(topology)`` input
            channels. Defaults to the previously hard-coded 20.
    """

    def __init__(self, n_down, topology, kernel_size, action_dim=0, style_dim=0, max_channel_per_joint=48,
                 in_channels_per_joint=20):
        super().__init__()

        self.topologies = [topology]
        self.mid_layers = nn.ModuleList()
        # Per-joint channel width after each down block (index 0 = network input).
        self.channel_base = [in_channels_per_joint]
        self.joint_num = [len(topology)]
        self.channel_list = []
        self.pooling_lists = []

        padding = (kernel_size - 1) // 2

        # Down-sampling blocks: reduce time length (stride 2) and possibly joints (pooling).
        for i in range(n_down):
            seq = []
            neighbour_list = find_neighbor_joint(self.topologies[-1], 2)
            in_channels = self.channel_base[-1] * self.joint_num[-1]
            out_channels_per_joint = min(max_channel_per_joint, self.channel_base[-1] * 2)
            out_channels = out_channels_per_joint * self.joint_num[-1]
            self.channel_base.append(out_channels_per_joint)
            if i == 0:
                self.channel_list.append(in_channels)
            self.channel_list.append(out_channels)
            seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                                    joint_num=self.joint_num[-1], kernel_size=kernel_size, stride=2,
                                    padding=padding, padding_mode="reflection", bias=True))
            pooler = SkeletonPoolJoint(self.topologies[-1], channels_per_joint=out_channels_per_joint)
            seq.append(pooler)
            seq.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
            self.mid_layers.append(nn.Sequential(*seq))

            # Only record a new resolution level when pooling actually merged joints.
            if len(pooler.new_topology) < self.joint_num[-1]:
                self.topologies.append(pooler.new_topology)
                self.pooling_lists.append(pooler.pooling_list)
                self.joint_num.append(len(pooler.new_topology))

        # Spatial branch: keeps time length and skeleton structure.
        sp_conv = []
        neighbour_list = find_neighbor_joint(self.topologies[-1], 2)
        # forward() may append action_dim channels to every joint.
        in_channels = (self.channel_base[-1] + action_dim) * self.joint_num[-1]
        out_channels = self.channel_base[-1] * self.joint_num[-1]
        sp_conv.append(SkeletonConv(neighbour_list, in_channels, out_channels, joint_num=self.joint_num[-1],
                                    kernel_size=kernel_size, stride=1, padding=padding,
                                    padding_mode="reflection", bias=True))
        sp_conv.append(nn.LeakyReLU(0.2, inplace=True))
        sp_conv.append(nn.InstanceNorm1d(out_channels, affine=False))

        in_channels = out_channels
        out_channels = in_channels // 2
        sp_conv.append(SkeletonConv(neighbour_list, in_channels, out_channels, joint_num=self.joint_num[-1],
                                    kernel_size=kernel_size, stride=1, padding=padding,
                                    padding_mode="reflection", bias=True))
        sp_conv.append(nn.LeakyReLU(0.2, inplace=True))
        sp_conv.append(nn.InstanceNorm1d(out_channels, affine=False))
        self.sp_conv = nn.Sequential(*sp_conv)
        self.sp_logvar = SkeletonConv(neighbour_list, out_channels, out_channels, joint_num=self.joint_num[-1],
                                      kernel_size=1, stride=1, padding=0, bias=False)
        self.sp_mu = SkeletonConv(neighbour_list, out_channels, out_channels, joint_num=self.joint_num[-1],
                                  kernel_size=1, stride=1, padding=0, bias=False)
        self.sp_channel = out_channels

        # Global branch: collapses time to a single step, keeps skeleton structure.
        gl_conv = []
        # forward() may append style_dim channels to every joint.
        in_channels = (self.channel_base[-1] + style_dim) * self.joint_num[-1]
        out_channels = self.channel_base[-1] * self.joint_num[-1] * 2
        gl_conv.append(SkeletonConv(neighbour_list, in_channels, out_channels, joint_num=self.joint_num[-1],
                                    kernel_size=kernel_size, stride=2, padding=padding,
                                    padding_mode="reflection", bias=True))
        gl_conv.append(nn.AdaptiveAvgPool1d(1))
        gl_conv.append(nn.LeakyReLU(0.2, inplace=True))
        gl_conv.append(SkeletonConv(neighbour_list, out_channels, out_channels, joint_num=self.joint_num[-1],
                                    kernel_size=1, stride=1, padding=0))
        gl_conv.append(nn.LeakyReLU(0.2, inplace=True))
        self.gl_conv = nn.Sequential(*gl_conv)
        self.gl_logvar = SkeletonConv(neighbour_list, out_channels, out_channels, joint_num=self.joint_num[-1],
                                      kernel_size=1, stride=1, padding=0, bias=False)
        self.gl_mu = SkeletonConv(neighbour_list, out_channels, out_channels, joint_num=self.joint_num[-1],
                                  kernel_size=1, stride=1, padding=0, bias=False)
        self.gl_channel = out_channels
        self.action_dim = action_dim
        self.style_dim = style_dim

    def _append_per_joint(self, feat, vecs, vec_dim):
        """Append ``vecs`` to the channels of every joint at every frame of ``feat``.

        Args:
            feat: (B, joint_num * C, L) features whose channels are grouped
                per joint.
            vecs: per-sample vectors holding ``B * vec_dim`` elements.
            vec_dim: width of each vector.

        Returns:
            Tensor of shape (B, joint_num * (C + vec_dim), L).

        Raises:
            ValueError: if ``vecs`` and ``feat`` disagree on batch size.
        """
        B, NC, L = feat.shape
        if vecs.shape[0] != B:
            raise ValueError(
                f"conditioning vectors have batch size {vecs.shape[0]}, expected {B}")
        joint_num = self.joint_num[-1]
        per_joint = feat.permute(0, 2, 1).view(B, L, joint_num, NC // joint_num)
        tiled = vecs.view(B, 1, 1, vec_dim).repeat(1, L, joint_num, 1)
        merged = torch.cat([per_joint, tiled], dim=-1).view(B, L, -1)
        return merged.permute(0, 2, 1)

    def forward(self, input, action_vecs=None, style_vecs=None):
        """Encode ``input`` into spatial and global Gaussian parameters.

        Args:
            input: (B, C, L) motion features for the first down block.
            action_vecs: optional action vectors, ``B * action_dim`` elements.
            style_vecs: optional style vectors, ``B * style_dim`` elements.

        Returns:
            ``(sp_mu, sp_logvar, gl_mu, gl_logvar)``. The spatial pair is the
            output of the 1x1 heads on the time-resolved spatial code. The
            global pair has shape (B, gl_channel).
        """
        for layer in self.mid_layers:
            input = layer(input)

        if action_vecs is not None:
            sp_input = self._append_per_joint(input, action_vecs, self.action_dim)
        else:
            sp_input = input

        if style_vecs is not None:
            gl_input = self._append_per_joint(input, style_vecs, self.style_dim)
        else:
            gl_input = input

        sp_code = self.sp_conv(sp_input)
        gl_code = self.gl_conv(gl_input)
        sp_mu = self.sp_mu(sp_code)
        sp_logvar = self.sp_logvar(sp_code)
        gl_mu = self.gl_mu(gl_code)
        gl_logvar = self.gl_logvar(gl_code)
        # Drop only the pooled length-1 time axis. A bare squeeze() would also
        # drop the batch axis when B == 1.
        return sp_mu, sp_logvar, gl_mu.squeeze(-1), gl_logvar.squeeze(-1)

class AdaptiveInstanceNorm1d(nn.Module):
    """Instance normalisation whose scale and shift come from a style code.

    ``style_net`` (a ``SkeletonLinear``) maps the style code to
    ``2 * in_channel`` values, which are split in half along the last axis.
    The first half is the scale (gamma) and the second half is the shift
    (beta). The first ``in_channel`` bias entries are set to 1 and the rest
    to 0. NOTE(review): this assumes the bias layout matches the last-axis
    split in ``forward``; confirm against ``SkeletonLinear``.
    """

    def __init__(self, neighbour_list, in_channel, style_channel):
        super().__init__()
        self.norm = nn.InstanceNorm1d(in_channel)
        self.joint_num = len(neighbour_list)
        self.style_net = SkeletonLinear(neighbour_list=neighbour_list,
                                        in_channels=style_channel,
                                        out_channels=2 * in_channel)
        bias = self.style_net.bias.data
        bias[:in_channel] = 1  # gamma half
        bias[in_channel:] = 0  # beta half

    def forward(self, input, style):
        """Normalise ``input`` (B, C, L) and modulate it with ``style``.

        ``style`` is mapped by ``style_net`` to gamma and beta, each reshaped
        to (B, C, 1) and broadcast over the time axis.
        """
        batch, _, _ = input.shape
        gamma, beta = self.style_net(style).chunk(2, -1)
        scale = gamma.view(batch, -1).unsqueeze(-1)
        shift = beta.view(batch, -1).unsqueeze(-1)
        return scale * self.norm(input) + shift

class StyleSkeletonConv1DLayer(nn.Module):
    """Stride-1 ``SkeletonConv``, then LeakyReLU(0.2), then ``AdaptiveInstanceNorm1d``.

    The final normalisation is modulated by a style code with
    ``style_channel`` channels.
    """

    def __init__(self, neighbour_list, in_channel, out_channel, style_channel, kernel_size, padding,
                 padding_mode="reflection", bias=True):
        super().__init__()
        joints = len(neighbour_list)
        self.conv = SkeletonConv(neighbour_list, in_channel, out_channel, joint_num=joints,
                                 kernel_size=kernel_size, stride=1, padding=padding,
                                 padding_mode=padding_mode, bias=bias)
        self.adain1 = AdaptiveInstanceNorm1d(neighbour_list, out_channel, style_channel)
        self.lrelu1 = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, input, style):
        """Apply conv and activation, then style-conditioned normalisation."""
        activated = self.lrelu1(self.conv(input))
        return self.adain1(activated, style)

class Decoder(nn.Module):
    """Style-modulated skeleton decoder that walks an encoder's pooling hierarchy in reverse.

    Stage 1 applies ``n_conv`` ``StyleSkeletonConv1DLayer`` blocks on the
    encoder's coarsest skeleton, ``encoder.topologies[-1]``. Stage 2 has
    ``n_up`` up-sampling stages. Each stage doubles the time length with
    nearest-neighbour ``nn.Upsample``. While the encoder still has pooling
    levels left, the stage also unpools the skeleton with ``SkeletonUnpool``
    (the ``encoder.pooling_lists`` entry for that level). The global style
    code is unpooled with it and mapped by ``SkeletonLinear`` to half its
    width. A final 1x1 ``out_conv`` projects to ``target_channel`` channels.

    Args:
        n_conv: number of stage-1 style convs.
        n_up: number of up-sampling stages.
        encoder: a ``StyleContentEncoder``. Its ``topologies``,
            ``pooling_lists`` and ``joint_num`` are read.
        kernel_size: temporal kernel size. Padding is ``(kernel_size - 1) // 2``.
        sp_channel: channel count of the spatial code passed to ``forward``.
        gl_channel: channel count of the global (style) code passed to ``forward``.
        target_channel: output channel count.
        action_dim: width of optional per-sample action vectors appended to
            every joint of the spatial code.
        style_dim: width of optional per-sample style vectors appended to
            every joint of the global code.
        max_channel_per_joint: cap on per-joint width in stage 1.
    """

    def __init__(self, n_conv, n_up, encoder, kernel_size, sp_channel, gl_channel, target_channel,
                 action_dim=0, style_dim=0, max_channel_per_joint=48):
        super().__init__()
        self.n_conv = n_conv
        self.n_up = n_up
        self.sp_conv = nn.ModuleList()
        self.sp_up_layers = nn.ModuleList()
        self.gl_up_layers = nn.ModuleList()
        self.gl_up_poolers = nn.ModuleList()
        self.sp_up_poolers = nn.ModuleList()
        padding = (kernel_size - 1) // 2

        # Start at the encoder's coarsest skeleton. `topology` is bound here so
        # it is defined even when no unpooling stage ever runs (n_up == 0 or
        # an encoder without pooling levels).
        topology = encoder.topologies[-1]
        neighbour_list = find_neighbor_joint(topology, 2)
        self.joint_num = len(neighbour_list)
        self.style_dim = style_dim
        self.action_dim = action_dim

        # Width of the style code after forward() appends style_dim per joint.
        gl_ic = gl_channel + self.joint_num * style_dim

        # Stage 1: style convs at the coarsest resolution.
        sp_ic = sp_channel
        for i in range(n_conv):
            sp_oc = min(sp_ic * 2, max_channel_per_joint * self.joint_num)
            if i == 0:
                # forward() appends action_dim channels per joint before the first conv.
                sp_ic += action_dim * self.joint_num
            self.sp_conv.append(StyleSkeletonConv1DLayer(neighbour_list, sp_ic, sp_oc, gl_ic, kernel_size, padding,
                                                         bias=True))
            sp_ic = sp_oc

        # Stage 2: up-sampling. Stages without a skeleton unpool keep the
        # style width unchanged, so start gl_oc at the current style width.
        joint_num = self.joint_num
        gl_oc = gl_ic
        for i in range(n_up):
            sp_ic_per = sp_ic // joint_num
            gl_ic_per = gl_ic // joint_num
            unpool = i < len(encoder.pooling_lists)
            if unpool:
                pooling_list = encoder.pooling_lists[-i - 1]
                self.sp_up_poolers.append(
                    nn.Sequential(SkeletonUnpool(pooling_list=pooling_list, channels_per_edge=sp_ic_per),
                                  nn.Upsample(scale_factor=2, mode="nearest")))
                self.gl_up_poolers.append(SkeletonUnpool(pooling_list=pooling_list, channels_per_edge=gl_ic_per))
                topology = encoder.topologies[-i - 2]
            else:
                self.sp_up_poolers.append(nn.Upsample(scale_factor=2, mode="nearest"))
            joint_num = len(topology)
            neighbour_list = find_neighbor_joint(topology, 2)
            sp_ic = joint_num * sp_ic_per
            sp_oc = max(sp_ic // 2, target_channel * 2)
            if unpool:
                gl_ic = joint_num * gl_ic_per
                gl_oc = gl_ic // 2
                self.gl_up_layers.append(SkeletonLinear(neighbour_list, gl_ic, gl_oc))
            self.sp_up_layers.append(StyleSkeletonConv1DLayer(neighbour_list, sp_ic, sp_oc, gl_oc, kernel_size, padding))
            gl_ic, sp_ic = gl_oc, sp_oc

        # Per-frame (1x1) projection to the target channel count.
        self.out_conv = SkeletonConv(neighbour_list, sp_ic, target_channel, kernel_size=1, joint_num=len(topology),
                                     stride=1, padding=0, bias=False)

    def _append_action(self, feat, action_vecs):
        """Append ``action_vecs`` to every joint at every frame of ``feat``.

        Maps (B, joint_num * C, L) to (B, joint_num * (C + action_dim), L).
        Channels are assumed to be grouped per joint.
        """
        B, NC, L = feat.shape
        joint_num = self.joint_num
        per_joint = feat.permute(0, 2, 1).view(B, L, joint_num, NC // joint_num)
        tiled = action_vecs.view(B, 1, 1, self.action_dim).repeat(1, L, joint_num, 1)
        return torch.cat([per_joint, tiled], dim=-1).view(B, L, -1).permute(0, 2, 1)

    def _append_style(self, style, style_vecs):
        """Append ``style_vecs`` to every joint of the (B, joint_num * C) style code."""
        B, SC = style.shape
        joint_num = self.joint_num
        per_joint = style.view(B, joint_num, SC // joint_num)
        tiled = style_vecs.view(B, 1, self.style_dim).repeat(1, joint_num, 1)
        return torch.cat([per_joint, tiled], dim=-1).view(B, -1)

    def forward(self, input, style, action_vecs=None, style_vecs=None):
        """Decode a spatial code, modulated by a global style code.

        Args:
            input: (B, sp_channel, L) spatial code.
            style: (B, gl_channel) global code.
            action_vecs: optional action vectors, ``B * action_dim`` elements.
            style_vecs: optional style vectors, ``B * style_dim`` elements.

        Returns:
            Output of ``out_conv`` with ``target_channel`` channels. The time
            axis is upsampled by 2 in each of the ``n_up`` stages.

        Raises:
            ValueError: if the batch sizes of the arguments disagree.
        """
        batch, _, _ = input.shape
        style_batch, _ = style.shape
        if style_batch != batch:
            raise ValueError(f"style has batch size {style_batch}, expected {batch}")

        if action_vecs is not None:
            if action_vecs.shape[0] != batch:
                raise ValueError(f"action_vecs has batch size {action_vecs.shape[0]}, expected {batch}")
            sp_input = self._append_action(input, action_vecs)
        else:
            sp_input = input

        if style_vecs is not None:
            if style_vecs.shape[0] != batch:
                raise ValueError(f"style_vecs has batch size {style_vecs.shape[0]}, expected {batch}")
            gl_input = self._append_style(style, style_vecs)
        else:
            gl_input = style

        for layer in self.sp_conv:
            sp_input = layer(sp_input, gl_input)

        for i in range(self.n_up):
            # Only the first len(gl_up_layers) stages unpool the skeleton.
            if i < len(self.gl_up_layers):
                gl_input = self.gl_up_poolers[i](gl_input)
                gl_input = self.gl_up_layers[i](gl_input)
            sp_input = self.sp_up_poolers[i](sp_input)
            sp_input = self.sp_up_layers[i](sp_input, gl_input)
        return self.out_conv(sp_input)